# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate data for test."""
import numpy as np

def get_init_params():
    """Return the constructor arguments used by this test case.

    As a side effect, NumPy's global random generator is seeded with 42.
    No generator visible in this file draws random numbers, so the seed is
    presumably kept for parity with sibling data files.
    NOTE(review): confirm it is still needed.
    """
    np.random.seed(42)
    params = dict(max_seq_len=5, offset=1)
    return params

def get_golden() -> dict[str, np.ndarray]:
    """Build the golden (reference) outputs for this test case.

    Returns a dict with two entries:
      * ``"output_1"``: a float64 array of shape (5, 1, 1, 32). Each
        32-wide row is a 16-value vector followed by an identical copy of
        it, so only the unique half of every row is written out below.
      * ``"mscale_1"``: the array ``[1.0]``.

    NOTE(review): the values look like ``(p + 1) * 10000 ** (-k / 16)`` for
    row ``p`` and ``k = 0..15``, rounded to float32. That would make this a
    rotary-embedding frequency table for positions 1..5, consistent with
    ``max_seq_len=5`` and ``offset=1``. Confirm against the operator under
    test.
    """
    halves = [
        [1.00000000e+00, 5.62341332e-01, 3.16227764e-01, 1.77827939e-01,
         1.00000001e-01, 5.62341288e-02, 3.16227786e-02, 1.77827943e-02,
         9.99999978e-03, 5.62341325e-03, 3.16227786e-03, 1.77827943e-03,
         1.00000005e-03, 5.62341302e-04, 3.16227786e-04, 1.77827940e-04],
        [2.00000000e+00, 1.12468266e+00, 6.32455528e-01, 3.55655879e-01,
         2.00000003e-01, 1.12468258e-01, 6.32455572e-02, 3.55655886e-02,
         1.99999996e-02, 1.12468265e-02, 6.32455572e-03, 3.55655886e-03,
         2.00000009e-03, 1.12468260e-03, 6.32455572e-04, 3.55655880e-04],
        [3.00000000e+00, 1.68702400e+00, 9.48683262e-01, 5.33483803e-01,
         3.00000012e-01, 1.68702394e-01, 9.48683321e-02, 5.33483848e-02,
         2.99999993e-02, 1.68702397e-02, 9.48683359e-03, 5.33483829e-03,
         3.00000003e-03, 1.68702391e-03, 9.48683359e-04, 5.33483806e-04],
        [4.00000000e+00, 2.24936533e+00, 1.26491106e+00, 7.11311758e-01,
         4.00000006e-01, 2.24936515e-01, 1.26491114e-01, 7.11311772e-02,
         3.99999991e-02, 2.24936530e-02, 1.26491114e-02, 7.11311772e-03,
         4.00000019e-03, 2.24936521e-03, 1.26491114e-03, 7.11311761e-04],
        [5.00000000e+00, 2.81170654e+00, 1.58113885e+00, 8.89139712e-01,
         5.00000000e-01, 2.81170636e-01, 1.58113897e-01, 8.89139697e-02,
         4.99999970e-02, 2.81170662e-02, 1.58113893e-02, 8.89139716e-03,
         5.00000035e-03, 2.81170662e-03, 1.58113893e-03, 8.89139716e-04],
    ]
    # ``half + half`` concatenates the Python lists (32 values). The
    # ``[[...]]`` wrapping supplies the two singleton axes of (5, 1, 1, 32).
    output_1 = np.array([[[half + half]] for half in halves])
    return {"output_1": output_1, "mscale_1": np.array([1.0])}

def get_gpu_datas() -> dict[str, np.ndarray]:
    """Build the GPU-side reference outputs for this test case.

    The returned dict has the same layout as the golden data:
      * ``"output_1"``: a float64 array of shape (5, 1, 1, 32) whose rows are
        a 16-value vector repeated twice along the last axis.
      * ``"mscale_1"``: the array ``[1.0]``.

    The values are currently element-for-element identical to the golden
    outputs. They are kept as an independent literal, presumably so the two
    references can diverge if the GPU result is re-captured.
    """
    unique_halves = np.array([
        [1.00000000e+00, 5.62341332e-01, 3.16227764e-01, 1.77827939e-01,
         1.00000001e-01, 5.62341288e-02, 3.16227786e-02, 1.77827943e-02,
         9.99999978e-03, 5.62341325e-03, 3.16227786e-03, 1.77827943e-03,
         1.00000005e-03, 5.62341302e-04, 3.16227786e-04, 1.77827940e-04],
        [2.00000000e+00, 1.12468266e+00, 6.32455528e-01, 3.55655879e-01,
         2.00000003e-01, 1.12468258e-01, 6.32455572e-02, 3.55655886e-02,
         1.99999996e-02, 1.12468265e-02, 6.32455572e-03, 3.55655886e-03,
         2.00000009e-03, 1.12468260e-03, 6.32455572e-04, 3.55655880e-04],
        [3.00000000e+00, 1.68702400e+00, 9.48683262e-01, 5.33483803e-01,
         3.00000012e-01, 1.68702394e-01, 9.48683321e-02, 5.33483848e-02,
         2.99999993e-02, 1.68702397e-02, 9.48683359e-03, 5.33483829e-03,
         3.00000003e-03, 1.68702391e-03, 9.48683359e-04, 5.33483806e-04],
        [4.00000000e+00, 2.24936533e+00, 1.26491106e+00, 7.11311758e-01,
         4.00000006e-01, 2.24936515e-01, 1.26491114e-01, 7.11311772e-02,
         3.99999991e-02, 2.24936530e-02, 1.26491114e-02, 7.11311772e-03,
         4.00000019e-03, 2.24936521e-03, 1.26491114e-03, 7.11311761e-04],
        [5.00000000e+00, 2.81170654e+00, 1.58113885e+00, 8.89139712e-01,
         5.00000000e-01, 2.81170636e-01, 1.58113897e-01, 8.89139697e-02,
         4.99999970e-02, 2.81170662e-02, 1.58113893e-02, 8.89139716e-03,
         5.00000035e-03, 2.81170662e-03, 1.58113893e-03, 8.89139716e-04],
    ])
    # np.tile with reps=2 on a 2-D array repeats along the last axis, giving
    # (5, 32). The reshape then inserts the two singleton axes.
    output_1 = np.tile(unique_halves, 2).reshape(5, 1, 1, 32)
    mscale_1 = np.array([1.0])
    return dict(output_1=output_1, mscale_1=mscale_1)

# Materialize both reference datasets once at import time. The right-hand
# side is evaluated left to right, so get_golden() still runs first.
GOLDEN_DATA, GPU_DATA = get_golden(), get_gpu_datas()
